// Copyright 2024 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


syntax = "proto2";

package automl_zero;



// The tasks to evaluate, expressed as a list of task specifications.
message TaskCollection {
  // One entry per specification. A single specification can describe several
  // tasks (see TaskSpec.num_tasks).
  repeated TaskSpec tasks = 1;  // Non-empty.
}

// Selects how a task is evaluated. The values permitted for a given task kind
// are documented with that task kind (see TaskSpec.eval_type).
//
// Numbers 2 and 3 are not assigned in this enum.
enum EvalType {
  // First value, hence the proto2 default when the field is unset.
  INVALID_EVAL_TYPE = 0;
  // NOTE(review): presumably root-mean-square error -- confirm against the
  // evaluation code.
  RMS_ERROR = 1;
  // The value required by ProjectedBinaryClassificationTask.
  ACCURACY = 4;
}

// Encodes information about a task of a given kind. The fields outside the
// `task_type` oneof are shared by every kind; the oneof selects the kind and
// carries any kind-specific settings.
message TaskSpec {
  // Size of each features vector. This also sets the size of all vectors and
  // matrices in the memory.
  optional int32 features_size = 13;

  // Number of unique training examples.
  optional int32 num_train_examples = 1;  // Required.

  // Number of times the training examples will be repeated to mimic multiple
  // epochs over a fixed training set. Defaults to a single epoch.
  optional int32 num_train_epochs = 21 [default = 1];

  // Number of validation examples.
  optional int32 num_valid_examples = 2;  // Required.

  // Number of tasks with this specification.
  optional int32 num_tasks = 3;

  // Seeds for the features. If data_seeds have n elements (n > 0), they will
  // be used as the seeds for the first n tasks, the seeds for the rest will
  // be incrementing from the last seed in data_seeds. If data_seeds is empty,
  // the default seeds will be used. See FillTasks function in
  // task_util.cc for more details.
  // TODO(crazydonkey): make sure the random seed is never 0.
  repeated uint32 data_seeds = 4;

  // Seeds for the parameters that determine the labels function. Same rules
  // as for data_seeds apply.
  // TODO(crazydonkey): make sure the random seed is never 0.
  repeated uint32 param_seeds = 5;

  // See task_type case for allowed EvalType values.
  optional EvalType eval_type = 28;  // Required.

  // The kind of task. At most one of the members below can be set.
  oneof task_type {
    // Linear regression task.
    ScalarLinearRegressionTaskSpec scalar_linear_regression_task = 6;

    // Non-linear regression task generated by 2 layer NN.
    Scalar2LayerNNRegressionTaskSpec scalar_2layer_nn_regression_task = 7;

    // Binary classification task generated by randomly projecting MNIST or
    // CIFAR-10 to lower dimensions.
    ProjectedBinaryClassificationTask projected_binary_classification_task = 24;

    // Useful tasks for tests.
    UnitTestFixedTask unit_test_fixed_task = 40;
    UnitTestZerosTaskSpec unit_test_zeros_task = 45;
    UnitTestOnesTaskSpec unit_test_ones_task = 46;
    UnitTestIncrementTaskSpec unit_test_increment_task = 47;
  }

  // Number of test examples. Used for final evaluation.
  optional int32 num_test_examples = 18;
}

// Activation function selector.
// NOTE(review): no message in the visible part of this file references this
// enum -- confirm where it is consumed before changing it.
enum ActivationType {
  // First value, hence the proto2 default when a field of this type is unset.
  RELU = 0;
  TANH = 1;
}

// Selects the linear regression task (TaskSpec.scalar_linear_regression_task).
// It has no settings of its own; the task is configured through the shared
// TaskSpec fields.
message ScalarLinearRegressionTaskSpec {}

// Selects the non-linear regression task generated by a 2 layer NN
// (TaskSpec.scalar_2layer_nn_regression_task). It has no settings of its own;
// the task is configured through the shared TaskSpec fields.
message Scalar2LayerNNRegressionTaskSpec {}

// A projected binary classification task. These use pre-generated datasets.
// The following TaskSpec fields are restricted to the given values:
//   eval_type: ACCURACY.
//   num_train_examples: the value should be an integer in (0, 8000]
//   num_valid_examples: the value should be an integer in (0, 1000]
//   num_test_examples: the value should be an integer in (0, 1000]
//   param_seeds: not used by this task, so their values do not matter.
// Below are the supported choices for dataset_name, features_size,
// min_supported_data_seed and max_supported_data_seed:
// |dataset_name|features_size|min/max_supported_data_seed|
// |------------|-------------|---------------------------|
// |mnist       |16           |0 / 100                    |
// |cifar10     |16           |0 / 100                    |
//
// Meta-train / meta-validation / meta-test split:
// Since some positive-negative pairs are held out,
// you can use all the seeds during search (meta-train) and use the held-out
// pairs in model selection and evaluation (meta-validation and meta-test).
// Among all 45 possible pairs, we recommend that the following
// 9 randomly selected pairs be held out for meta-validation and meta-test:
// (4, 6), (3, 5), (8, 9), (3, 8), (0, 9), (2, 9), (1, 8), (3, 6), (0, 5).
// If transferring to the original feature size is used as final evaluation
// (meta-test), you can use all the held-out pairs as meta-validation.
// If no transferring is used, you can use the first 4 pairs as
// meta-validation and the remaining 5 pairs as meta-test.
message ProjectedBinaryClassificationTask {
  // Below are the IDs for the positive and negative classes, you should
  // either specify:
  // (1) both of them, in this case, the given positive and negative classes
  // will be used;
  // (2) none of them, in this case, the positive and negative classes will
  // be randomly chosen based on the data_seed.
  //
  // Both values should be integers in [0, 9] with the `positive_class` smaller
  // than the `negative_class`.
  optional int32 positive_class = 1;
  optional int32 negative_class = 2;

  // Name to specify the dataset to use, currently supporting "mnist" and
  // "cifar10".
  optional string dataset_name = 3;

  // There are two possible sources to get the projected data:
  // (1) the data is saved in this proto, i.e., all the features and
  //     labels of the train/validation/test set are saved in the
  //     `dataset` field.
  // (2) If `path` is set, it will be used as the path to the folder
  //     containing all the serialized data.
  oneof task_source {
    string path = 4;
    ScalarLabelDataset dataset = 5;
  }

  // Pairs to hold out when randomizing the dataset.
  repeated ClassPair held_out_pairs = 6;

  // Minimum (incl.) and maximum (excl.) data seeds supported in the sstable
  // that saves the dumped projected dataset.
  //
  // Only seeds in the range specified in the table above are supported.
  // The seed is obtained by mapping `data_seed` into the range with
  //     seed = (data_seed % (max_supported_data_seed-min_supported_data_seed) +
  //             min_supported_data_seed)
  // (1) when the dataset is not randomized, the specified `positive_class` and
  // `negative_class` will be used;
  // (2) when the dataset is randomized, i.e., when `positive_class` and
  // `negative_class` are not set, the `data_seed` is also used
  // to randomly select the positive and negative classes.
  //
  // The defaults match the 0 / 100 range listed in the table above for both
  // supported datasets. A smaller maximum would make the modulo mapping fold
  // every `data_seed` into a narrower set of seeds than the datasets provide.
  // TODO(crazydonkey): make sure the random seed is never 0.
  optional int32 min_supported_data_seed = 7 [default = 0];
  optional int32 max_supported_data_seed = 8 [default = 100];
}

// A (positive class, negative class) pair of class IDs. Used as the element
// type of ProjectedBinaryClassificationTask.held_out_pairs.
message ClassPair {
  // ID of the positive class.
  optional int32 positive_class = 1;
  // ID of the negative class.
  optional int32 negative_class = 2;
}

// A dataset stored inline, with one scalar (float) label per example. Used as
// the `dataset` source of ProjectedBinaryClassificationTask.
message ScalarLabelDataset {
  // Training, validation and test examples.
  // NOTE(review): each *_features list is presumably index-aligned with the
  // matching *_labels list and of equal length -- the schema does not enforce
  // this; confirm against the code that reads it.
  repeated FeatureVector train_features = 1;
  repeated float train_labels = 2;
  repeated FeatureVector valid_features = 3;
  repeated float valid_labels = 4;
  repeated FeatureVector test_features = 5;
  repeated float test_labels = 6;
}

// The features of a single example, stored as floats. Used by
// ScalarLabelDataset.
message FeatureVector {
  // NOTE(review): the length presumably equals TaskSpec.features_size --
  // confirm against the reader.
  repeated float features = 1;
}

// A task where the data is specified explicitly during construction.
// Useful for unit tests.
message UnitTestFixedTask {
  // Training, validation and test examples. Both features and labels are
  // stored as vectors of doubles (UnitTestFixedTaskVector).
  repeated UnitTestFixedTaskVector train_features = 1;
  repeated UnitTestFixedTaskVector train_labels = 2;
  repeated UnitTestFixedTaskVector valid_features = 3;
  repeated UnitTestFixedTaskVector valid_labels = 4;
  repeated UnitTestFixedTaskVector test_features = 5;
  repeated UnitTestFixedTaskVector test_labels = 6;

  // Field number 7 is reserved and must not be assigned to a new field.
  reserved 7;
}

// A vector of doubles. Used for both the features and the labels of
// UnitTestFixedTask.
message UnitTestFixedTaskVector {
  repeated double elements = 1;
}

// Selects the unit-test "zeros" task (TaskSpec.unit_test_zeros_task). It has
// no settings of its own.
// NOTE(review): presumably the generated data is all zeros -- confirm against
// the task generator.
message UnitTestZerosTaskSpec {}

// Selects the unit-test "ones" task (TaskSpec.unit_test_ones_task). It has no
// settings of its own.
// NOTE(review): presumably the generated data is all ones -- confirm against
// the task generator.
message UnitTestOnesTaskSpec {}

// Settings for the unit-test "increment" task
// (TaskSpec.unit_test_increment_task).
message UnitTestIncrementTaskSpec {
  // Step size; defaults to 1.0 when unset.
  // NOTE(review): presumably the amount by which generated values grow from
  // one example to the next -- confirm against the task generator.
  optional double increment = 1 [default = 1.0];
}
